import socket
import os
import time
import traceback
import numpy as np
from colorama import init
init()


_ANSI_RESET = "\033[0m"


def _print_wrapped(prefix, items):
    """Print *items* space-separated between an ANSI color *prefix* and a reset code."""
    print(prefix, *items, _ANSI_RESET)


def print红(*kw):
    """Print the arguments in red."""
    _print_wrapped("\033[0;31m", kw)


def print绿(*kw):
    """Print the arguments in green."""
    _print_wrapped("\033[0;32m", kw)


def print黄(*kw):
    """Print the arguments in yellow."""
    _print_wrapped("\033[0;33m", kw)


def print蓝(*kw):
    """Print the arguments in blue."""
    _print_wrapped("\033[0;34m", kw)


def print紫(*kw):
    """Print the arguments in magenta (purple)."""
    _print_wrapped("\033[0;35m", kw)


def print靛(*kw):
    """Print the arguments in cyan (indigo)."""
    _print_wrapped("\033[0;36m", kw)

def find_free_index(path):
    """Return the smallest buffer index not yet used by a log file in *path*.

    An index ``t`` counts as used when either ``mCOMv5_buffer_<t>.txt`` or
    ``mCOMv5_buffer_<t>____starting_session.txt`` exists inside *path*.
    The directory (including parents) is created if it does not exist.

    :param path: directory holding the mCOMv5 buffer files.
    :return: the first free integer index, starting from 0.
    """
    # exist_ok=True: a separate "exists?" check followed by makedirs races
    # with any other process creating the same directory in between.
    os.makedirs(path, exist_ok=True)
    t = 0
    while (os.path.exists(os.path.join(path, 'mCOMv5_buffer_%d.txt' % t))
           or os.path.exists(os.path.join(path, 'mCOMv5_buffer_%d____starting_session.txt' % t))):
        t += 1
    return t

class mCOMv5:
    """Logger that writes MATLAB-style plotting commands to rotating text files.

    Every command is written as raw bytes to a buffer file inside ``path``.
    A session starts in ``mCOMv5_buffer_<k>____starting_session.txt`` (``k``
    comes from ``find_free_index``). Once more than ``file_max_lines``
    commands have been written, the current file is terminated with
    ``><EndFileFlag`` and logging continues in ``mCOMv5_buffer_<k+1>.txt``,
    and so on. When ``port`` is given, the absolute path of the starting
    file is announced once over UDP to ``(ip, port)``.
    """

    # Design principle (original author): this helper should never disturb
    # the main program it is embedded in.

    def __init__(self, ip=None, port=None, path=None, digit=8, rapid_flush=True):
        """Open the first buffer file and optionally announce it over UDP.

        :param ip: UDP destination host for the session announcement.
        :param port: UDP destination port; no socket is created when None.
        :param path: directory for the buffer files (created if missing).
            ``None`` is not handled and will fail in ``find_free_index``.
        :param digit: digits after the decimal point used when numbers are
            written in scientific notation (``rec`` / ``other_cmd``). Default
            8; smaller values (e.g. 4 or 6) produce smaller output.
        :param rapid_flush: flush the file after every command, useful when
            the data rate is low and the reader should see lines promptly.
        """
        self.port = port
        self.dst = (ip, port)
        self.path = path
        self.current_buffer_index = find_free_index(self.path)
        self.starting_file = self.path+'/mCOMv5_buffer_%d____starting_session.txt'%(self.current_buffer_index)
        self.current_file_handle = open(self.starting_file,'wb+')
        self.file_lines_cnt = 0
        # Cap the number of commands per file so no single buffer file grows too large.
        self.file_max_lines = 30000
        self.digit = digit
        self.rapid_flush = rapid_flush
        if self.port is not None:
            self.socketx = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            str_tmp = ">>@"+ str(self.current_buffer_index) + "@" + os.path.abspath(self.starting_file)
            b_tmp = bytes(str_tmp, encoding = 'utf8')
            self.socketx.sendto(b_tmp, self.dst)
        print蓝('**************mCOMv5 service initialized**************')
        print蓝('use MATLAB to open live log file at:' + self.starting_file)

    def send(self, data):
        """Core function: append raw bytes *data* to the current buffer file.

        ``file_lines_cnt`` counts calls to ``send`` (each command is one
        line). After more than ``file_max_lines`` calls, the file is closed
        with an ``><EndFileFlag`` marker and a new file with the next index
        is opened.
        """
        # step 1: add to file
        self.file_lines_cnt += 1
        self.current_file_handle.write(data)
        if self.rapid_flush:
            self.current_file_handle.flush()

        # step 2: if the file is too large, move on to the next file.
        if self.file_lines_cnt > self.file_max_lines:
            self.current_file_handle.write(b'><EndFileFlag\n')
            self.current_file_handle.close()
            self.current_buffer_index += 1
            self.current_file_handle = open((self.path+'/mCOMv5_buffer_%d.txt'%self.current_buffer_index),'wb+')
            self.file_lines_cnt = 0

        # NOTE: per-command UDP streaming is disabled; only the session start
        # is announced over UDP in __init__.

    def rec_init(self, color='k'):
        """Emit ``>>rec_init('<color>')``."""
        str_tmp = '>>rec_init(\'%s\')\n'%color
        self.send(bytes(str_tmp, encoding='utf8'))

    def rec_show(self):
        """Emit ``>>rec_show``."""
        self.send(b'>>rec_show\n')

    def rec_end(self):
        """Emit ``>>rec_end``."""
        self.send(b'>>rec_end\n')

    def rec_save(self):
        """Emit ``>>rec_save``."""
        self.send(b'>>rec_save\n')

    def rec_end_hold(self):
        """Emit ``>>rec_end_hold``."""
        self.send(b'>>rec_end_hold\n')

    def rec_clear(self, name):
        """Emit ``>>rec_clear("<name>")``."""
        str_tmp = '>>rec_clear("%s")\n'%(name)
        self.send(bytes(str_tmp, encoding='utf8'))

    def rec(self, value, name):
        """Core function: record one point of the line called *name*.

        Points are assigned one by one, for example::

            uc.rec(100, 'live loss valueX')
            uc.rec(0.1, 'entropy agentX')
            uc.rec(99,  'live loss valueX')
            uc.rec(0.3, 'entropy agentX')

        *value* is converted with ``float`` and written in scientific
        notation with ``self.digit`` digits after the decimal point.
        """
        value = float(value)
        # '%.*e' takes the precision from the argument tuple, so any digit
        # count works (previously only 4/6/8; others raised UnboundLocalError).
        str_tmp = '>>rec(%.*e,"%s")\n' % (int(self.digit), value, name)
        self.send(bytes(str_tmp, encoding='utf8'))

    def 发送虚幻4数据流(self, x, y, z, pitch, yaw, roll):
        """Send one UE4 pose line ``UE4>>("agent#1",x,y,z,pitch,yaw,roll)`` (%.6e)."""
        x = float(x)
        y = float(y)
        z = float(z)
        pitch = float(pitch)
        yaw = float(yaw)
        roll = float(roll)
        str_tmp = 'UE4>>(\"agent#1\",%.6e,%.6e,%.6e,%.6e,%.6e,%.6e)\n'%(x,y,z,pitch,yaw,roll)
        self.send(bytes(str_tmp, encoding='utf8'))

    def 发送虚幻4数据流_多智能体(self, x_, y_, z_, pitch_, yaw_, roll_):
        """Send poses for several agents on one UE4 line.

        The six sequences are zipped element-wise; each pose becomes
        ``("agent#1",x,y,z,pitch,yaw,roll);`` (%.5e). Extra elements in
        longer sequences are ignored by ``zip``.
        """
        # NOTE(review): every entry is labelled "agent#1"; presumably the
        # receiver identifies agents by position — confirm.
        str_list = ['UE4>>']
        for x,y,z,pitch,yaw,roll in zip(x_,y_,z_,pitch_,yaw_,roll_):
            x = float(x)
            y = float(y)
            z = float(z)
            pitch = float(pitch)
            yaw = float(yaw)
            roll = float(roll)
            str_tmp = '(\"agent#1\",%.5e,%.5e,%.5e,%.5e,%.5e,%.5e)' %(x,y,z,pitch,yaw,roll)
            str_list.append(str_tmp)
            str_list.append(';')
        str_list.append('\n')
        self.send(bytes(''.join(str_list), encoding='utf8'))

    def other_cmd(self, *args):
        """Emit ``>><caller>(arg1,arg2,...)`` where <caller> is the calling method's name.

        Argument formatting:
          * int / float / NumPy integer or floating scalar: scientific
            notation with ``self.digit`` digits after the decimal point;
          * str: wrapped in single quotes;
          * list: ``str(list)``;
          * 1-D ``np.ndarray``: ``[v1 v2 ... ]`` with each value as ``%.3e``;
          * higher-dimensional arrays and other types are skipped with a
            printed warning.
        """
        # The command name comes from the *calling* function (e.g. ``plot``),
        # so every forwarder below must be a real function with that name.
        # limit=2 extracts only [caller, other_cmd] instead of the whole stack.
        func_name = traceback.extract_stack(limit=2)[-2].name
        precision = int(self.digit)
        parts = []
        for arg in args:
            if isinstance(arg, (int, float, np.integer, np.floating)):
                parts.append('%.*e' % (precision, float(arg)))
            elif isinstance(arg, str):
                parts.append("'" + arg + "'")
            elif isinstance(arg, list):
                parts.append(str(arg))
            elif isinstance(arg, np.ndarray):
                # Format the current argument (previously always args[0]).
                if arg.ndim == 1:
                    parts.append('[' + ''.join('%.3e ' % t for t in arg) + ']')
                elif arg.ndim == 2:
                    print红('mcom：输入数组的维度大于1维，目前处理不了。')
                else:
                    print红('mcom：输入数组的维度大于2维，目前处理不了。')
            else:
                print('error building cmd')
        cmd = '>>' + func_name + '(' + ','.join(parts) + ')\n'
        self.send(bytes(cmd, encoding='utf8'))

    # Forwarders: each method's own name becomes the emitted command name
    # (resolved by other_cmd through stack introspection).
    def plot(self, *args):
        self.other_cmd(*args)

    def figure(self, *args):
        self.other_cmd(*args)

    def hold(self, *args):
        self.other_cmd(*args)

    def box(self, *args):
        self.other_cmd(*args)

    def pause(self, *args):
        self.other_cmd(*args)

    def clf(self, *args):
        self.other_cmd(*args)

    def xlim(self, *args):
        self.other_cmd(*args)

    def ylim(self, *args):
        self.other_cmd(*args)

    def xlabel(self, *args):
        self.other_cmd(*args)

    def ylabel(self, *args):
        self.other_cmd(*args)

    def drawnow(self, *args):
        self.other_cmd(*args)

    def v2d(self, *args):
        self.other_cmd(*args)

    def v2d_init(self, *args):
        self.other_cmd(*args)

    def v2L(self, *args):
        self.other_cmd(*args)

    def title(self, *args):
        self.other_cmd(*args)

    def plot3(self, *args):
        self.other_cmd(*args)

    def grid(self, *args):
        self.other_cmd(*args)

    def __del__(self):
        """Write ``><EndTaskFlag`` to the open buffer file and close resources.

        ``getattr`` guards let an object whose ``__init__`` failed part-way
        (e.g. an invalid ``path``) be collected without raising.
        """
        handle = getattr(self, 'current_file_handle', None)
        if handle is not None and not handle.closed:
            handle.write(b'><EndTaskFlag\n')
            handle.close()
        if getattr(self, 'port', None) is not None and hasattr(self, 'socketx'):
            self.disconnect()
        print蓝('the program exited, mCOMv5 as well exited!')

    def disconnect(self):
        """Close the UDP socket created in ``__init__``."""
        self.socketx.close()


# if __name__ == "__main__":
#     mcv = mCOMv5(ip='127.0.0.1', port=12084, path='./mcomv5/', digit=8, rapid_flush=True)
#     mcv.rec_init()
#     for y in range(100):
#         for x in range(1000):
#             mcv.rec(x,'rewardx')
#         mcv.rec_show()
#         time.sleep(1)
            
#     pass